Add AG/RS overlap distributed init support #2487

Open

jeffnvidia wants to merge 2 commits into NVIDIA-NeMo:main from jeffnvidia:ag_rs_overlap_process_group_collection

Conversation


@jeffnvidia commented Feb 23, 2026

What does this PR do?

Enable AG/RS overlap process-group plumbing in Megatron-Bridge distributed initialization by adding support for create_all_gather_group, including expert-parallel all-gather groups when EP > 1, and wiring these groups into ProcessGroupCollection.

Changelog

  • Add create_all_gather_group: bool = False to DistributedInitConfig in src/megatron/bridge/training/config.py.
  • Update src/megatron/bridge/training/initialize.py to thread create_all_gather_group through distributed init.
  • In the decentralized (HyperCommGrid) path, create and attach:
    • dp_cp_ag (DP+CP all-gather group)
    • expt_dp_ag (expert-DP all-gather group when EP is enabled)
  • In the centralized path, call parallel_state.create_all_gather_groups(...) after model-parallel init and attach returned AG groups to ProcessGroupCollection.
  • Keep default behavior unchanged when create_all_gather_group=False (no AG group creation, existing behavior preserved).
  • No intended packaging/build-system changes are part of this PR (pyproject.toml local testing override is excluded).
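The new flag can be pictured as a plain dataclass field. Below is a hypothetical, stripped-down sketch of DistributedInitConfig; the real class in src/megatron/bridge/training/config.py carries many more fields.

```python
# Hypothetical, stripped-down sketch; the real DistributedInitConfig in
# src/megatron/bridge/training/config.py has many more fields.
from dataclasses import dataclass


@dataclass
class DistributedInitConfig:
    distributed_timeout_minutes: int = 10
    # When True, distributed init additionally builds dedicated all-gather
    # (AG) process groups (dp_cp_ag, plus expt_dp_ag when EP > 1) so that
    # AG/RS communication can overlap with compute in FSDP training.
    create_all_gather_group: bool = False
```

The default stays False, so existing behavior is preserved unless the user explicitly opts in.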

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? No
    • Reviewer: Does the PR have correct import guards for all optional libraries? N/A

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

Summary by CodeRabbit

  • New Features
    • Added a configuration option to enable optimized process groups for improved distributed training performance.


copy-pr-bot bot commented Feb 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Feb 23, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

Configuration used: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9ecee6c and 7f4e63f.

📒 Files selected for processing (2)
  • src/megatron/bridge/training/config.py
  • src/megatron/bridge/training/initialize.py

📝 Walkthrough

Walkthrough

This change adds a new configuration field create_all_gather_group to DistributedInitConfig and implements conditional creation and integration of all-gather process groups into the ProcessGroupCollection during distributed initialization, with support for both decentralized and centralized initialization paths.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Configuration: src/megatron/bridge/training/config.py | Added new boolean field create_all_gather_group to DistributedInitConfig, with documentation explaining its role in enabling the AG/RS overlap optimization for FSDP training. |
| Initialization logic: src/megatron/bridge/training/initialize.py | Added a create_all_gather_group parameter to _create_pg_collection to conditionally build and attach the dp_cp_ag and expt_dp_ag process groups. The parameter is propagated through _initialize_distributed, with separate code paths for the decentralized and centralized initialization modes. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Config as DistributedInitConfig
    participant InitDist as _initialize_distributed
    participant CreatePG as _create_pg_collection
    participant ParallelState as parallel_state
    participant PGCollection as ProcessGroupCollection

    Config->>InitDist: create_all_gather_group flag
    InitDist->>CreatePG: create_all_gather_group=True

    alt Decentralized path (HyperCommGrid)
        CreatePG->>CreatePG: build dp_cp_ag_pg
        CreatePG->>CreatePG: build expt_dp_ag_pg (if expert parallel)
        CreatePG->>PGCollection: attach dp_cp_ag & expt_dp_ag
    else Centralized path
        CreatePG->>ParallelState: create_all_gather_groups()
        ParallelState->>PGCollection: populate AG groups
    end

    PGCollection-->>InitDist: return with AG groups
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • ananthsub

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Test Results For Major Changes | ⚠️ Warning | PR lacks documented test results and performance metrics for the major AG/RS overlap optimization feature in distributed initialization. | Add comprehensive unit tests for create_all_gather_group, quantified performance benchmarks with specific metrics, hardware configuration details, convergence validation, and complete documentation updates. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title clearly and concisely summarizes the main change: adding support for AG/RS overlap in distributed initialization configuration. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



yaoyu-33 previously approved these changes Mar 2, 2026

yaoyu-33 commented Mar 2, 2026

/ok to test 7f4e63f

Signed-off-by: jeffnvidia <jmahou@nvidia.com>
Made-with: Cursor
@jeffnvidia force-pushed the ag_rs_overlap_process_group_collection branch from dc30cd1 to ea44a27 on March 10, 2026 09:25
Comment on lines +674 to +687
```python
    # Create AG groups if requested
    if dist_config.create_all_gather_group:
        for_expert_parallelism = (getattr(model_config, "expert_model_parallel_size", 1) or 1) > 1
        dp_cp_ag, expt_dp_ag = parallel_state.create_all_gather_groups(
            for_expert_parallelism=for_expert_parallelism,
            timeout=datetime.timedelta(minutes=dist_config.distributed_timeout_minutes),
            nccl_comm_cfgs=None,  # Could use dist_config.nccl_communicator_config_path if needed
        )
        # Get ProcessGroupCollection and populate with AG groups
        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
        pg_collection.dp_cp_ag = dp_cp_ag
        if expt_dp_ag is not None:
            pg_collection.expt_dp_ag = expt_dp_ag
        return pg_collection
```
@cspades (Contributor) commented Mar 16, 2026

Note to self: Adding this to the PG collection will be passed to the Megatron-FSDP DDP FullyShardedDataParallel adapter which then passes it to the FSDPDistIndex / MegatronFSDP class API.

```python
if create_all_gather_group:
    # Create regular DP all-gather group with same ranks as dp_cp_pg
    # Use HyperCommGrid to enumerate ranks for dp-cp groups
    dp_cp_rank_lists = grid._gen_rank_enum(["dp", "cp"])
```
Contributor

Why is "grid.create_pg(...)" not working here? Ideally we shouldn't use an internal API; it's a bit risky.

Author

From what I can tell, grid.create_pg(["dp", "cp"]) can't be used here because it was already called on line 415 to create dp_cp_pg.

If I understand correctly, calling it again would raise a KeyError — create_pg keys by dimension names only (line 151 of hyper_comm_grid.py), so any second call with ["dp", "cp"] would collide with the existing "dp-cp" entry, regardless of group_desc or pg_options.

The AG group needs the same ranks but as an independent NCCL communicator, so I used _gen_rank_enum to get the rank lists and passed them to new_subgroups_by_enumeration directly. Same situation for the expert grid on line 516.

That said, I'm not very familiar with the HyperCommGrid internals, so I'm happy to refactor if there's a preferred way to create a second PG with the same rank topology.
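The collision the author describes can be illustrated with a toy registry keyed by joined dimension names. This is a hypothetical sketch, not the actual HyperCommGrid code.

```python
# Toy model of a registry keyed only by joined dimension names: a second
# create_pg(["dp", "cp"]) collides with the existing "dp-cp" entry, which
# is why the AG group's ranks must be enumerated directly instead.
class GridSketch:
    def __init__(self):
        self._pgs = {}

    def create_pg(self, dims):
        key = "-".join(dims)  # keyed by dimension names only
        if key in self._pgs:
            raise KeyError(f"process group '{key}' already exists")
        self._pgs[key] = object()  # stand-in for an NCCL communicator
        return self._pgs[key]


grid = GridSketch()
grid.create_pg(["dp", "cp"])  # first call succeeds (dp_cp_pg)
try:
    grid.create_pg(["dp", "cp"])  # second call with the same dims collides
except KeyError as err:
    print("collision:", err)
```

Because torch.distributed.new_subgroups_by_enumeration accepts explicit rank lists, passing the enumerated ranks to it directly sidesteps the name-keyed registry entirely.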

@jeffnvidia (Author)

@cspades @yaoyu-33, let me know if I need to make any additional changes, thanks!
